#!/usr/bin/env python
#
# Copyright 2016-present the Material Components for iOS authors. All Rights
# Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import sys
from collections import namedtuple

# Prefix shared by all MDC umbrella headers (e.g. 'Material<Component>.h').
UMBRELLA_HEADER_PREFIX = 'Material'

# Import prefixes that should not be checked. An import whose path starts with
# any of these entries is skipped by the import filter.
IMPORT_FILTER_PREFIXES = (
    # Objective-C imports
    'Availability.h',
    'CoreFoundation/',
    'CoreGraphics/',
    'Foundation/',
    'QuartzCore/',
    'UIKit/',
    'objc/',
    # Third-party imports
    'MDF',
    'Motion',
)

# Import suffixes that should not be checked. An import whose path ends with
# any of these entries is skipped by the import filter.
IMPORT_FILTER_SUFFIXES = (
    # Color themers do not have umbrella headers.
    'ColorThemer.h',
)

# Regex to match imports of two kinds:
# 1) #import "Foo.h" - the header path is captured by regex group 2
# 2) #import <Foo.h> - the header path is captured by regex group 3
# Regex group 1 is the whole delimited token (quotes or angle brackets
# included), so in the tuple returned by match.groups() the header paths sit at
# indices 1 and 2.
IMPORT_RE = re.compile(r'#import (\"(.*\.h)\"|\<(.*\.h)\>)')


def umbrella_header(component):
    """Builds the umbrella header file name for the component.

    The name is UMBRELLA_HEADER_PREFIX followed by the component name and the
    '.h' extension.
    """
    return '{prefix}{name}.h'.format(prefix=UMBRELLA_HEADER_PREFIX,
                                     name=component)


def framework_umbrella_header(component):
    """Builds the framework-style umbrella header path for the component.

    This is the component umbrella header placed under the
    'MaterialComponents/' framework directory.
    """
    header = umbrella_header(component)
    return 'MaterialComponents/{0}'.format(header)


# Context in which the check is executed. It is handed to every check function
# together with the import string being checked.
# @checked_file - the name of the file that is checked (check_path passes the
#                 base name of the file; it is used in error messages).
# @module_files - list of files in currently analyzed sub-component
#                 (check_path passes paths relative to the checked directory).
CheckContext = namedtuple('CheckContext', ['checked_file', 'module_files'])


class ImportChecks(object):
    """A class containing factory methods for various checks and filters.

    Every factory returns a function `check(import_str, check_context)` that
    returns a `(stop, error_msg)` tuple:
      - `stop` is True when no further checks should be run for the import;
      - `error_msg` is a string describing a violation, or None if the import
        passed this check.
    Filters never report an error; they only set `stop` for imports that must
    not be checked any further.
    """

    @staticmethod
    def import_filter():
        """Returns a function to filter Objective-C and third-party framework imports."""

        def check(import_str, _):
            # startswith/endswith accept a tuple and match any of its entries.
            skip = (import_str.startswith(IMPORT_FILTER_PREFIXES) or
                    import_str.endswith(IMPORT_FILTER_SUFFIXES))
            return skip, None

        return check

    @staticmethod
    def module_import_filter():
        """Returns a function to filter same module file imports.

        An import is treated as a same-module import when it names a module
        file exactly, or names the trailing path components of one (for
        example 'MDCFoo.h' matches the module file 'private/MDCFoo.h').
        """

        def check(import_str, check_context):
            # Match on a path-component boundary only. A plain endswith() would
            # let 'Button.h' match the unrelated module file 'MDCButton.h'.
            path_suffix = '/' + import_str
            for module_file in check_context.module_files:
                # Imports always use '/', module files use the OS separator.
                normalized = module_file.replace(os.sep, '/')
                if normalized == import_str or normalized.endswith(path_suffix):
                    return True, None
            return False, None

        return check

    @staticmethod
    def non_umbrella_header_check():
        """Returns a function to check that an umbrella header is imported."""

        def check(import_str, check_context):
            if not import_str.startswith(UMBRELLA_HEADER_PREFIX):
                error_msg = '{} imports a non-umbrella header: {}'.format(check_context.checked_file, import_str)
                return True, error_msg
            return False, None

        return check

    @staticmethod
    def own_umbrella_header_check(umbrella_header):
        """Returns a function to check that checked component umbrella header is not used.

        `umbrella_header` is the exact import string that must not appear.
        """

        def check(import_str, check_context):
            if import_str == umbrella_header:
                return True, '{} imports an own umbrella header: {}'.format(check_context.checked_file, import_str)
            return False, None

        return check


class ImportChecker(object):
    """Runs an ordered sequence of check functions over import strings.

    Each check function takes `(import_str, check_context)` and returns a
    `(stop, error_msg)` tuple.
    """

    def __init__(self, check_functions):
        self.check_functions = check_functions

    def check_import(self, import_str, check_context):
        """Runs the checks in order and returns an error message if a check failed.

        Returns None when no check reported an error. Checks that follow one
        which asked to stop are not run.
        """
        for run_check in self.check_functions:
            should_stop, message = run_check(import_str, check_context)
            if message:
                # The first reported violation wins.
                return message
            if should_stop:
                # A filter accepted the import; nothing more to verify.
                return None

        return None

    @staticmethod
    def src_checker(component):
        """Returns a checker for the 'src' code of the given component."""
        own_header = umbrella_header(component)
        own_framework_header = framework_umbrella_header(component)
        checks = [
            ImportChecks.import_filter(),
            ImportChecks.own_umbrella_header_check(own_header),
            ImportChecks.own_umbrella_header_check(own_framework_header),
            ImportChecks.module_import_filter(),
            ImportChecks.non_umbrella_header_check(),
        ]
        return ImportChecker(checks)

    @staticmethod
    def general_checker():
        """Returns a general purpose checker."""
        checks = [
            ImportChecks.import_filter(),
            ImportChecks.module_import_filter(),
            ImportChecks.non_umbrella_header_check(),
        ]
        return ImportChecker(checks)


def imports_in_file(file_path):
    """Returns the list of header paths imported by the file, in file order.

    Only lines that IMPORT_RE matches from their first character contribute;
    the surrounding quotes or angle brackets are not part of the result.
    """
    found = []
    with open(file_path) as source:
        for line in source:
            matched = IMPORT_RE.match(line)
            if matched is None:
                # This line is not an import; move on to the next one.
                continue

            # groups() is (whole token, quoted header, angle-bracket header).
            _, quoted, bracketed = matched.groups()
            header = quoted or bracketed
            if header:
                found.append(header)

    return found


def list_objc_files(path, skip_dirs=()):
    """Returns a list of Objective-C files in the path.

    Args:
        path: directory that is walked recursively.
        skip_dirs: directory names to exclude. A directory with one of these
            names is skipped together with everything below it.

    Returns:
        Paths of all '.h' and '.m' files found, relative to `path`. The list
        is empty if `path` does not exist.
    """
    files_relative_paths = []
    for dirpath, dirnames, files in os.walk(path):
        # Prune in place so that os.walk does not descend into skipped
        # directories; otherwise their nested subdirectories would be listed.
        dirnames[:] = [d for d in dirnames if d not in skip_dirs]
        if os.path.basename(dirpath) in skip_dirs:
            # Only the walk root can get here, since skipped children are
            # pruned above; its own files are skipped, as before.
            continue
        for f in files:
            if f.endswith(('.h', '.m')):
                files_relative_paths.append(os.path.relpath(os.path.join(dirpath, f), path))

    return files_relative_paths


def check_src_path(src_path, component_name):
    """Returns whether src_path has any import errors.

    Capitalized directories directly inside src_path are treated as separate
    modules: each is checked on its own with the default checker and excluded
    from the check of src_path itself, which uses the 'src' checker of the
    component.
    """
    # If a directory in 'src' is capitalized, treat it as a separate module.
    sub_modules = [
        entry for entry in os.listdir(src_path)
        if os.path.isdir(os.path.join(src_path, entry)) and entry[0].isupper()
    ]

    submodules_have_errors = False
    for sub_module in sub_modules:
        # Check every sub-module, even after an earlier one reported errors.
        if check_path(os.path.join(src_path, sub_module)):
            submodules_have_errors = True

    src_checker = ImportChecker.src_checker(component_name)
    src_has_errors = check_path(src_path,
                                checker=src_checker,
                                excludes=sub_modules)
    return src_has_errors or submodules_have_errors


def check_path(path, checker=None, excludes=()):
    """Returns whether imports in files at path have import errors.

    Every violation is printed to stdout as 'ERROR: <message>'.

    Args:
        path: directory whose Objective-C files are checked.
        checker: ImportChecker to run over every import; when omitted, a
            general purpose checker is used.
        excludes: directory names that are left out of the check.

    Returns:
        True if at least one import violation was found, False otherwise.
    """
    if checker is None:
        # Built per call rather than once at function definition time.
        checker = ImportChecker.general_checker()

    has_errors = False
    files = list_objc_files(path, skip_dirs=excludes)
    for f in files:
        check_context = CheckContext(checked_file=os.path.basename(f),
                                     module_files=files)

        imports = imports_in_file(os.path.join(path, f))
        for import_str in imports:
            error = checker.check_import(import_str, check_context)
            if error:
                # Single-argument call form: valid and identical in output on
                # both Python 2 and Python 3.
                print('ERROR: {}'.format(error))
                has_errors = True

    return has_errors


def main(args):
    """Checks imports of the component at args[0] and exits the process.

    Both the 'src' and the 'examples' directories of the component are
    checked. The process exits with status -1 if the component path is missing
    from `args` or any import violation was found, and with 0 otherwise.

    Args:
        args: command-line arguments without the program name; the first one
            is the path to the component directory.
    """
    if len(args) < 1:
        # Single-argument call form: valid on both Python 2 and Python 3.
        print('ERROR: Component path should be passed as an input parameter.')
        sys.exit(-1)

    component_path = os.path.normpath(args[0])
    component_name = os.path.basename(component_path)

    src_path = os.path.join(component_path, 'src')
    examples_path = os.path.join(component_path, 'examples')

    # Run both checks unconditionally. Combining the calls with `or` would
    # skip the examples (and hide their errors) once 'src' has errors.
    src_has_errors = check_src_path(src_path, component_name)
    examples_have_errors = check_path(examples_path)
    has_errors = src_has_errors or examples_have_errors

    sys.exit(-1 if has_errors else 0)


# Script entry point: pass the command-line arguments without the program name.
if __name__ == '__main__':
    main(sys.argv[1:])
